{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "50a965a4-79a6-4a1b-96be-f8c857c16d46",
   "metadata": {},
   "source": [
    "# Tile Language in BitBLAS\n",
    "\n",
    "A more flexible and more efficient tile programming language compared with Triton.\n",
    "\n",
    "## Features\n",
    "\n",
    "- **Simplified Syntax**: Write GPU kernels with a more straightforward and expressive syntax.\n",
    "- **High Performance**: Achieve performance comparable to manually optimized implementations.\n",
    "- **Advanced Operations**: Support for complex operations like convolutions, flash-attention, and normalizations.\n",
    "- **Compatibility**: Works with modern CUDA architectures.\n",
    "\n",
    "## OP Examples\n",
    "\n",
    "- [Matrix Multiplication](#Get-Started-with-a-GEMM-Example)\n",
    "- [Flash Attention](#Flash-Attention-V3)\n",
    "- [Dequantization GEMM](#Implement-Dequantize-GEMM-with-simple-Syntax)\n",
    "- [RetNet](#retina-net)\n",
    "- [MAMBA](#mamba)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ff371661-da38-4b37-9e77-330d812aa7e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import Tile Language from bitblas\n",
    "from bitblas import tvm as tvm\n",
    "from tvm import tl\n",
    "import tvm.tl.language as T"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb42b2fa-0d66-4f22-8bd9-7561dfdde15e",
   "metadata": {},
   "source": [
    "## Get Started with a GEMM Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9af2499b-bca8-4ab0-bb2a-4926de21fde1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Problem size: A is (M, K), B is (N, K) and C is (M, N).\n",
    "M = N = K = 256\n",
    "\n",
    "A_shape = (M, K)\n",
    "B_shape = (N, K)\n",
    "C_shape = (M, N)\n",
    "in_dtype = out_dtype = accum_dtype = \"float16\"\n",
    "\n",
    "# Tile sizes handled per kernel block / K step, and the launch configuration.\n",
    "block_M = block_N = 128\n",
    "block_K = 32\n",
    "threads = 128\n",
    "num_stages = 2\n",
    "\n",
    "A_shared_shape = (block_M, block_K)\n",
    "B_shared_shape = (block_N, block_K)\n",
    "\n",
    "\n",
    "@T.prim_func\n",
    "def main(\n",
    "    A: T.Buffer(A_shape, in_dtype),\n",
    "    B: T.Buffer(B_shape, in_dtype),\n",
    "    C: T.Buffer(C_shape, out_dtype),\n",
    "):\n",
    "    # bx indexes tiles along N, by indexes tiles along M.\n",
    "    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):\n",
    "        A_shared = T.alloc_shared(A_shared_shape, in_dtype)\n",
    "        B_shared = T.alloc_shared(B_shared_shape, in_dtype)\n",
    "        C_local = T.alloc_fragment((block_M, block_N), accum_dtype)\n",
    "\n",
    "        # Zero the accumulator before summing over the K tiles.\n",
    "        T.clear(C_local)\n",
    "        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):\n",
    "            # Stage the current K-slice of A and B, then accumulate into C_local.\n",
    "            T.copy(A[by * block_M, k * block_K], A_shared)\n",
    "            T.copy(B[bx * block_N, k * block_K], B_shared)\n",
    "            # B tiles are (block_N, block_K), hence transpose_B=True.\n",
    "            T.gemm(A_shared, B_shared, C_local, transpose_A=False, transpose_B=True)\n",
    "\n",
    "        # Write the finished tile back to its position in C.\n",
    "        T.copy(C_local, C[by * block_M, bx * block_N])\n",
    "\n",
    "\n",
    "func = main"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "5503e7a5-3ccb-4c7b-87a2-d6a9d66dea9e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "#include <tl_templates/cuda/gemm.h>\n",
      "#include <tl_templates/cuda/copy.h>\n",
      "#include <tl_templates/cuda/reduce.h>\n",
      "#include <tl_templates/cuda/ldsm.h>\n",
      "#include <tl_templates/cuda/threadblock_swizzle.h>\n",
      "\n",
      "extern \"C\" __global__ void __launch_bounds__(128) main_kernel(half_t* __restrict__ A, half_t* __restrict__ B, half_t* __restrict__ C) {\n",
      "  extern __shared__ __align__(1024) uchar buf_dyn_shmem[];\n",
      "  half_t C_local[128];\n",
      "  #pragma unroll\n",
      "  for (int i = 0; i < 64; ++i) {\n",
      "    *(uint1*)(C_local + (i * 2)) = make_uint1(__pack_half2(half_t(0.000000e+00f), half_t(0.000000e+00f)));\n",
      "  }\n",
      "  #pragma unroll\n",
      "  for (int i_1 = 0; i_1 < 4; ++i_1) {\n",
      "    tl::cp_async_gs<16>(buf_dyn_shmem+((((i_1 * 2048) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)), A+((((((int)blockIdx.y) * 32768) + (i_1 * 8192)) + ((((int)threadIdx.x) >> 2) * 256)) + ((((int)threadIdx.x) & 3) * 8)));\n",
      "  }\n",
      "  #pragma unroll\n",
      "  for (int i_2 = 0; i_2 < 4; ++i_2) {\n",
      "    tl::cp_async_gs<16>(buf_dyn_shmem+(((((i_2 * 2048) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 16384), B+((((((int)blockIdx.x) * 32768) + (i_2 * 8192)) + ((((int)threadIdx.x) >> 2) * 256)) + ((((int)threadIdx.x) & 3) * 8)));\n",
      "  }\n",
      "  tl::cp_async_commit();\n",
      "  for (int k = 0; k < 7; ++k) {\n",
      "    #pragma unroll\n",
      "    for (int i_3 = 0; i_3 < 4; ++i_3) {\n",
      "      tl::cp_async_gs<16>(buf_dyn_shmem+(((((((k + 1) & 1) * 8192) + (i_3 * 2048)) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)), A+((((((((int)blockIdx.y) * 32768) + (i_3 * 8192)) + ((((int)threadIdx.x) >> 2) * 256)) + (k * 32)) + ((((int)threadIdx.x) & 3) * 8)) + 32));\n",
      "    }\n",
      "    #pragma unroll\n",
      "    for (int i_4 = 0; i_4 < 4; ++i_4) {\n",
      "      tl::cp_async_gs<16>(buf_dyn_shmem+((((((((k + 1) & 1) * 8192) + (i_4 * 2048)) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 16384), B+((((((((int)blockIdx.x) * 32768) + (i_4 * 8192)) + ((((int)threadIdx.x) >> 2) * 256)) + (k * 32)) + ((((int)threadIdx.x) & 3) * 8)) + 32));\n",
      "    }\n",
      "    tl::cp_async_commit();\n",
      "    tl::cp_async_wait<1>();\n",
      "    __syncthreads();\n",
      "    tl::gemm_ss<128, 128, 32, 2, 2, 0, 1>((&(((half_t*)buf_dyn_shmem)[((k & 1) * 4096)])), (&(((half_t*)buf_dyn_shmem)[(((k & 1) * 4096) + 8192)])), (&(C_local[0])));\n",
      "  }\n",
      "  tl::cp_async_wait<0>();\n",
      "  __syncthreads();\n",
      "  tl::gemm_ss<128, 128, 32, 2, 2, 0, 1>((&(((half_t*)buf_dyn_shmem)[4096])), (&(((half_t*)buf_dyn_shmem)[12288])), (&(C_local[0])));\n",
      "  #pragma unroll\n",
      "  for (int i_5 = 0; i_5 < 64; ++i_5) {\n",
      "    *(uint1*)(C + (((((((((((int)blockIdx.y) * 32768) + (((i_5 & 7) >> 1) * 8192)) + (((((int)threadIdx.x) & 63) >> 5) * 4096)) + ((i_5 & 1) * 2048)) + (((((int)threadIdx.x) & 31) >> 2) * 256)) + (((int)blockIdx.x) * 128)) + ((i_5 >> 3) * 16)) + ((((int)threadIdx.x) >> 6) * 8)) + ((((int)threadIdx.x) & 3) * 2))) = *(uint1*)(C_local + (i_5 * 2));\n",
      "  }\n",
      "}\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Lower the TL prim_func, then show the source of its first imported module.\n",
    "rt_mod, params = tl.lower(func)\n",
    "kernel_source = rt_mod.imported_modules[0].get_source()\n",
    "print(kernel_source)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "212dfd69-b7d6-4a37-a4f3-07241c18488e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Assert Pass\n"
     ]
    }
   ],
   "source": [
    "# NOTE(review): `[2]` presumably marks buffer index 2 (C) as the kernel result - confirm against tl.Profiler.\n",
    "mod = tl.Profiler(rt_mod, params, [2], tl.TensorSupplyType.Integer)\n",
    "\n",
    "\n",
    "def ref_program(A, B):\n",
    "    \"\"\"PyTorch reference: A @ B.T computed in float32, then cast to `out_dtype`.\"\"\"\n",
    "    import torch\n",
    "\n",
    "    C = torch.matmul(A.to(torch.float), B.T.to(torch.float))\n",
    "    # Resolve the dtype name held in `out_dtype` to the torch dtype object.\n",
    "    return C.to(getattr(torch, out_dtype))\n",
    "\n",
    "\n",
    "mod.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)\n",
    "print(\"Assert Pass\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b4fd8fd-c2ca-4740-8c14-e3646993c9b4",
   "metadata": {},
   "source": [
    "## Manipulate Data Layout and Pipeline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aec3ac4e-385f-44cb-b439-2357901e1d86",
   "metadata": {},
   "source": [
    "TL also provides an interface for users to manipulate the memory layout and pipeline, and to enable rasterization for better L2 cache locality. Here is an example of how to use the memory layout and rasterization:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5da51fe1-e70d-447a-9097-6df9af6e16d6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The history saving thread hit an unexpected error (OperationalError('attempt to write a readonly database')).History will not be written to the database.\n"
     ]
    }
   ],
   "source": [
    "def matmul(\n",
    "    M,\n",
    "    N,\n",
    "    K,\n",
    "    block_M,\n",
    "    block_N,\n",
    "    block_K,\n",
    "    dtype=\"float16\",\n",
    "    accum_dtype=\"float\",\n",
    "    enable_rasterization=True,\n",
    "):\n",
    "    \"\"\"Build a TL kernel for C = A @ B with explicit layout and pipeline hints.\n",
    "\n",
    "    Args:\n",
    "        M, N, K: Problem sizes; A is (M, K), B is (K, N), C is (M, N).\n",
    "        block_M, block_N, block_K: Tile sizes handled per kernel block / K step.\n",
    "        dtype: Element type of A, B, C and of the shared tiles.\n",
    "        accum_dtype: Element type of the accumulator fragment.\n",
    "        enable_rasterization: Forwarded to `T.use_swizzle(enable=...)`.\n",
    "\n",
    "    Returns:\n",
    "        The `T.prim_func` describing the kernel.\n",
    "    \"\"\"\n",
    "\n",
    "    @T.prim_func\n",
    "    def main(\n",
    "        A: T.Buffer((M, K), dtype),\n",
    "        B: T.Buffer((K, N), dtype),\n",
    "        C: T.Buffer((M, N), dtype),\n",
    "    ):\n",
    "        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):\n",
    "            A_shared = T.alloc_shared((block_M, block_K), dtype)\n",
    "            B_shared = T.alloc_shared((block_K, block_N), dtype)\n",
    "            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)\n",
    "\n",
    "            # Apply memory layout optimizations\n",
    "            # Or you can define your own memory layout\n",
    "            T.annotate_layout({\n",
    "                A_shared: tl.layout.make_swizzled_layout(A_shared),\n",
    "                B_shared: tl.layout.make_swizzled_layout(B_shared),\n",
    "            })\n",
    "\n",
    "            # Enable rasterization for better L2 Cache Locality\n",
    "            T.use_swizzle(panel_size=10, enable=enable_rasterization)\n",
    "\n",
    "            # Clear the local buffer\n",
    "            T.clear(C_local)\n",
    "\n",
    "            # Auto pipeline the computation. The outer index is named `ko` so the\n",
    "            # inner Parallel loop's `k` does not shadow it.\n",
    "            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):\n",
    "                T.copy(A[by * block_M, ko * block_K], A_shared)\n",
    "\n",
    "                # Instead of using\n",
    "                # T.copy(B[ko * block_K, bx * block_N], B_shared)\n",
    "                # we can also use Parallel to auto map the thread\n",
    "                # bindings and vectorize the copy operation.\n",
    "                for k, j in T.Parallel(block_K, block_N):\n",
    "                    B_shared[k, j] = B[ko * block_K + k, bx * block_N + j]\n",
    "\n",
    "                T.gemm(A_shared, B_shared, C_local)\n",
    "\n",
    "            T.copy(C_local, C[by * block_M, bx * block_N])\n",
    "\n",
    "    return main"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e81ffc7-159a-4ae1-8719-2a0c03142f22",
   "metadata": {},
   "source": [
    "## Implement Dequantize GEMM with simple Syntax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0bae1971-9aa5-4db5-8a3f-f35aa495e4d8",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'T' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[2], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;129m@T\u001b[39m\u001b[38;5;241m.\u001b[39mprim_func\n\u001b[1;32m      2\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdequant_matmul\u001b[39m(\n\u001b[1;32m      3\u001b[0m     A: T\u001b[38;5;241m.\u001b[39mBuffer(A_shape, in_dtype),\n\u001b[1;32m      4\u001b[0m     B: T\u001b[38;5;241m.\u001b[39mBuffer(B_shape, storage_dtype),\n\u001b[1;32m      5\u001b[0m     Ct: T\u001b[38;5;241m.\u001b[39mBuffer((N, M), out_dtype),\n\u001b[1;32m      6\u001b[0m ):\n\u001b[1;32m      7\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;241m.\u001b[39mKernel(T\u001b[38;5;241m.\u001b[39mceildiv(N, block_N), T\u001b[38;5;241m.\u001b[39mceildiv(M, block_M), threads\u001b[38;5;241m=\u001b[39mthreads) \u001b[38;5;28;01mas\u001b[39;00m (bx, by):\n\u001b[1;32m      8\u001b[0m         A_shared \u001b[38;5;241m=\u001b[39m T\u001b[38;5;241m.\u001b[39malloc_shared(A_shared_shape, in_dtype)\n",
      "\u001b[0;31mNameError\u001b[0m: name 'T' is not defined"
     ]
    }
   ],
   "source": [
    "@T.prim_func\n",
    "def dequant_matmul(\n",
    "    A: T.Buffer(A_shape, in_dtype),\n",
    "    B: T.Buffer(B_shape, storage_dtype),\n",
    "    Ct: T.Buffer((N, M), out_dtype),\n",
    "):\n",
    "    # The result is produced transposed: Ct is (N, M) and each block owns a\n",
    "    # (block_N, block_M) tile of it.\n",
    "    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):\n",
    "        A_shared = T.alloc_shared(A_shared_shape, in_dtype)\n",
    "        B_shared = T.alloc_shared(B_shared_shape, storage_dtype)\n",
    "        B_local = T.alloc_fragment(B_shared_shape, storage_dtype)\n",
    "        B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)\n",
    "        Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)\n",
    "\n",
    "        T.clear(Ct_local)\n",
    "        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):\n",
    "            T.copy(A[by * block_M, k * block_K], A_shared)\n",
    "            # B is packed: num_elems_per_byte logical elements per stored element,\n",
    "            # so the packed column offset is the logical offset divided by that factor.\n",
    "            T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)\n",
    "            T.copy(B_shared, B_local)\n",
    "            for i, j in T.Parallel(block_N, block_K):\n",
    "                # Logical element j sits in packed element j // num_elems_per_byte\n",
    "                # at position j % num_elems_per_byte.\n",
    "                B_dequantize_local[i, j] = _tir_packed_to_unsigned_convert(\"int\", 8)(\n",
    "                    num_bits,\n",
    "                    B_local[i, j // num_elems_per_byte],\n",
    "                    j % num_elems_per_byte,\n",
    "                    dtype=in_dtype,\n",
    "                )\n",
    "            # A tiles are (block_M, block_K), hence transpose_B=True on the second operand.\n",
    "            T.gemm(B_dequantize_local, A_shared, Ct_local, transpose_B=True)\n",
    "        T.copy(Ct_local, Ct[bx * block_N, by * block_M])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48f49319-ad63-4673-8130-b010dc8ba22e",
   "metadata": {},
   "source": [
    "## If you want fine-grained control over dequantization at the thread level"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7e821e6-4517-47b7-8c68-b145eb513d99",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dequantize GEMM with explicit per-thread control: each thread loads packed\n",
    "# values of B from shared memory, decodes them into a local buffer, and writes\n",
    "# the decoded values to a second shared buffer that feeds T.gemm.\n",
    "# NOTE(review): this cell reuses the name `main` from the GEMM example above.\n",
    "@T.prim_func\n",
    "def main(\n",
    "        A: T.Buffer(A_shape, in_dtype),\n",
    "        B: T.Buffer(B_shape, storage_dtype),\n",
    "        C: T.Buffer((M, N), out_dtype),\n",
    "):\n",
    "    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):\n",
    "        A_shared = T.alloc_shared(A_shared_shape, in_dtype)\n",
    "        B_shared = T.alloc_shared(B_shared_shape, storage_dtype)\n",
    "        # Per-thread staging buffers: packed input and decoded output.\n",
    "        B_local = T.alloc_local([local_size_compressed], storage_dtype)\n",
    "        B_dequantize_local = T.alloc_local([local_size], in_dtype)\n",
    "        B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)\n",
    "        C_local = T.alloc_fragment((block_M, block_N), accum_dtype)\n",
    "\n",
    "        # Explicit thread index, used below to split the tile between threads.\n",
    "        tx = T.thread_binding(0, threads, thread=\"threadIdx.x\")\n",
    "\n",
    "        T.clear(C_local)\n",
    "        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):\n",
    "            T.copy(A[by * block_M, k * block_K], A_shared)\n",
    "            T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)\n",
    "\n",
    "            # Number of rounds needed for `threads` threads, each handling\n",
    "            # local_size_compressed packed elements, to cover the packed B tile.\n",
    "            for i in T.serial(block_N * block_K // num_elems_per_byte //\n",
    "                              (threads * local_size_compressed)):\n",
    "                # 1) Load this thread's packed elements from B_shared.\n",
    "                for v in T.vectorized(0, local_size_compressed):\n",
    "                    # Flat index into the packed tile, split into (row, packed column).\n",
    "                    index = i * threads * local_size_compressed + tx * local_size_compressed + v\n",
    "                    vi = index // (block_K // num_elems_per_byte)\n",
    "                    vj = index % (block_K // num_elems_per_byte)\n",
    "                    B_local[v] = B_shared[vi, vj]\n",
    "                # 2) Decode: logical element v comes from packed element\n",
    "                #    v // num_elems_per_byte at position v % num_elems_per_byte.\n",
    "                for v in T.serial(0, local_size):\n",
    "                    B_dequantize_local[v] = _tir_packed_to_unsigned_convert(\n",
    "                        storage_type, storage_nbit)(\n",
    "                            num_bits,\n",
    "                            B_local[v // num_elems_per_byte],\n",
    "                            v % num_elems_per_byte,\n",
    "                            dtype=in_dtype,\n",
    "                        )\n",
    "                # 3) Store the decoded elements to the (row, column) they occupy\n",
    "                #    in the unpacked tile.\n",
    "                for v in T.vectorized(0, local_size):\n",
    "                    index = i * threads * local_size + tx * local_size + v\n",
    "                    vi = index // block_K\n",
    "                    vj = index % block_K\n",
    "                    B_dequantize_shared[vi, vj] = B_dequantize_local[v]\n",
    "\n",
    "            # Accumulate using the decoded B tile.\n",
    "            T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True)\n",
    "\n",
    "        T.copy(C_local, C[by * block_M, bx * block_N])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37d65188-ce68-4056-bbc3-eaa8ab3b2316",
   "metadata": {},
   "source": [
    "## Flash Attention V3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72f095d2-9978-43c4-ab63-21264e97adb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "@T.prim_func\n",
    "def flash_attention_v3(\n",
    "    Q: T.Buffer(shape, dtype),\n",
    "    K: T.Buffer(shape, dtype),\n",
    "    V: T.Buffer(shape, dtype),\n",
    "    Output: T.Buffer(shape, dtype),\n",
    "):\n",
    "    # Grid: bx walks block_M-sized query tiles, by walks heads, bz walks the batch.\n",
    "    # Buffers are indexed as [batch, sequence, head, dim].\n",
    "    with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=thread_num) as (bx, by, bz):\n",
    "        Q_shared = T.alloc_shared([block_M, dim], dtype)\n",
    "        K_shared = T.alloc_shared([block_N, dim], dtype)\n",
    "        V_shared = T.alloc_shared([block_N, dim], dtype)\n",
    "        acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)\n",
    "        acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)\n",
    "        acc_o = T.alloc_fragment([block_M, dim], accum_dtype)\n",
    "        scores_max = T.alloc_fragment([block_M], accum_dtype)\n",
    "        scores_max_prev = T.alloc_fragment([block_M], accum_dtype)\n",
    "        scores_scale = T.alloc_fragment([block_M], accum_dtype)\n",
    "        scores_sum = T.alloc_fragment([block_M], accum_dtype)\n",
    "        logsum = T.alloc_fragment([block_M], accum_dtype)\n",
    "\n",
    "        T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)})\n",
    "        T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared)\n",
    "        T.fill(acc_o, 0)\n",
    "        T.fill(logsum, 0)\n",
    "        T.fill(scores_max, -T.infinity(accum_dtype))\n",
    "        # With the mask enabled, only key tiles up to the end of this query tile are visited.\n",
    "        loop_range = (\n",
    "            T.ceildiv((bx + 1) * block_M, block_N) if is_casual else T.ceildiv(seq_len, block_N)\n",
    "        )\n",
    "        for k in T.Pipelined(loop_range, num_stages=num_stages):\n",
    "            T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared)\n",
    "            if is_casual:\n",
    "                # Pre-load the mask: 0 where query index >= key index, -inf elsewhere.\n",
    "                for i, j in T.Parallel(block_M, block_N):\n",
    "                    acc_s[i, j] = T.if_then_else(\n",
    "                        bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)\n",
    "                    )\n",
    "            else:\n",
    "                T.clear(acc_s)\n",
    "            # acc_s += Q_tile @ K_tile^T\n",
    "            T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)\n",
    "            T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared)\n",
    "            # acc_s is [block_M, block_N], so the scaling loop must span block_N\n",
    "            # columns (iterating over `dim` is only correct when dim == block_N).\n",
    "            # NOTE(review): exp2 is used below, so `scale` presumably folds in log2(e) - confirm.\n",
    "            for i, j in T.Parallel(block_M, block_N):\n",
    "                acc_s[i, j] *= scale\n",
    "            # Update the running row maximum, keeping the previous one for rescaling.\n",
    "            T.copy(scores_max, scores_max_prev)\n",
    "            T.fill(scores_max, -T.infinity(accum_dtype))\n",
    "            T.reduce_max(acc_s, scores_max, dim=1, clear=False)\n",
    "            for i in T.Parallel(block_M):\n",
    "                scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i])\n",
    "            # Rescale the output accumulated so far to the new maximum.\n",
    "            for i, j in T.Parallel(block_M, dim):\n",
    "                acc_o[i, j] *= scores_scale[i]\n",
    "            for i, j in T.Parallel(block_M, block_N):\n",
    "                acc_s[i, j] = T.exp2(acc_s[i, j] - scores_max[i])\n",
    "            # Cast the probabilities to the input dtype for the second GEMM.\n",
    "            T.copy(acc_s, acc_s_cast)\n",
    "            T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)\n",
    "            T.reduce_sum(acc_s, scores_sum, dim=1)\n",
    "            for i in T.Parallel(block_M):\n",
    "                logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]\n",
    "        # Normalize by the accumulated row sums and write the tile out.\n",
    "        for i, j in T.Parallel(block_M, dim):\n",
    "            acc_o[i, j] /= logsum[i]\n",
    "        T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
